import tkinter as tk
import numpy as np
import matplotlib
matplotlib.use("TkAgg")
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import sounddevice as sd
import socket
import json
from scipy.signal import butter, sosfilt

# --------------------------
# Parameters
# --------------------------
N: int = 500              # number of lattice points drawn in the scatter
BANDS: int = 8            # number of frequency bands the audio is split into
dt: float = 0.05          # animation frame period in seconds (timer interval = dt * 1000 ms)
audio_scale: float = 0.01 # radial push per unit of band envelope
env_k: float = 2.0        # tanh gain passed to soft_envelope
size_base: int = 20       # marker size for a point with zero envelope

# --------------------------
# Tkinter + figure
# --------------------------
root = tk.Tk()
root.title("HDGL Analog 1:1 Multi-band Lattice (Amplitude & Color)")

# A single 3D axes fills the figure; the TkAgg canvas embeds it in the window
# and is packed to grow with it.
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(1, 1, 1, projection='3d')
canvas = FigureCanvasTkAgg(fig, master=root)
canvas.get_tk_widget().pack(expand=1, fill=tk.BOTH, side=tk.TOP)

# --------------------------
# Sliders
# --------------------------
# Horizontal strip along the bottom of the window that holds every slider.
slider_frame = tk.Frame(master=root)
slider_frame.pack(fill=tk.X, side=tk.BOTTOM)

def make_slider(label, minv, maxv, default):
    """Add a captioned horizontal slider to ``slider_frame``.

    The caption and the slider are both packed to the left, so successive
    calls line up side by side.

    Args:
        label: caption text shown before the slider.
        minv: slider minimum.
        maxv: slider maximum.
        default: initial value.

    Returns:
        The ``tk.DoubleVar`` bound to the slider (read it with ``.get()``).
    """
    caption = tk.Label(slider_frame, text=label)
    caption.pack(side=tk.LEFT)
    value = tk.DoubleVar(value=default)
    tk.Scale(
        slider_frame,
        from_=minv,
        to=maxv,
        resolution=0.01,
        orient=tk.HORIZONTAL,
        variable=value,
    ).pack(side=tk.LEFT)
    return value

# morph_var: blend weight toward the straight-line layout in get_lattice_multiband.
# ampl_var: multiplier applied to the lattice z coordinates.
morph_var = make_slider("Morph (Polar→Cartesian)", minv=0, maxv=1, default=0)
ampl_var = make_slider("Amplitude Scale", minv=0, maxv=2, default=1)

# --------------------------
# Lattice setup
# --------------------------
# Spiral lattice: point i sits at radius sqrt(i) and angle i * theta.
# 2*pi/phi equals minus the golden angle modulo 2*pi, so the points spread
# evenly over the disc.
# NOTE(review): radii reach sqrt(N - 1) (about 22 for N = 500), while update()
# fixes the x/y limits at +/-5, so most of the spiral lies outside that box.
# Confirm this is intended.
phi = (1 + np.sqrt(5)) / 2
theta = 2 * np.pi / phi
radii = np.sqrt(np.arange(N))
angles = theta * np.arange(N)
zs = np.linspace(-1, 1, N)  # evenly spaced heights from -1 to 1

def soft_envelope(signal, k=2.0):
    """Squash *signal* smoothly into (-1, 1) by computing ``tanh(k * signal)``.

    A larger *k* gives a steeper curve that saturates sooner.
    """
    gained = k * signal
    return np.tanh(gained)

# --------------------------
# Audio input
# --------------------------
# Most recent mono input block. audio_callback rebinds it to a fresh array,
# and update() reads it each animation frame.
audio_buffer = np.zeros(1024)


def audio_callback(indata, frames, time, status):
    """sounddevice input callback: publish the newest block of channel 0.

    Args:
        indata: input samples for this block, indexed (frame, channel).
        frames: number of frames in the block (unused).
        time: timing info from sounddevice (unused).
        status: stream status flags from sounddevice (unused).
    """
    global audio_buffer
    # ``indata`` belongs to the stream and may be reused after this callback
    # returns, so store a private copy rather than a view into it. Rebinding
    # the name to a fresh array also means the reader sees either the old
    # block or the new one, never a partly overwritten mix.
    audio_buffer = indata[:, 0].copy()

# Mono capture at 44.1 kHz, which is also band_split's default fs.
# Each block is handed to audio_callback.
stream = sd.InputStream(samplerate=44100, channels=1, callback=audio_callback)
stream.start()

# --------------------------
# Wi-Fi EMF sensor (UDP)
# --------------------------
# Listen on every interface for sensor datagrams. A zero timeout puts the
# socket in non-blocking mode, so polling it never stalls the GUI.
UDP_IP, UDP_PORT = "0.0.0.0", 5005
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
sock.bind((UDP_IP, UDP_PORT))
sock.settimeout(0.0)

def get_wifi_emf(udp_sock=None):
    """Return the newest ``"emf"`` reading received over UDP.

    Every datagram currently queued on the non-blocking socket is read, up to
    a fixed cap per call. The last one that parses wins, so a sender that
    transmits faster than the frame rate cannot build up a backlog of stale
    readings.

    Args:
        udp_sock: socket to poll. Defaults to the module-level ``sock``.

    Returns:
        The ``emf`` field (as float) of the newest valid JSON object, or 0.0
        when no valid datagram is waiting. A missing ``emf`` key reads as 0.0.
    """
    if udp_sock is None:
        udp_sock = sock
    latest = 0.0
    for _ in range(256):  # cap per call so a flooding sender can't stall a frame
        try:
            data, _addr = udp_sock.recvfrom(1024)
        except OSError:
            # BlockingIOError means the queue is empty. Other socket errors
            # are treated as "no reading", as before.
            break
        try:
            payload = json.loads(data.decode())
            latest = float(payload.get("emf", 0.0))
        except (UnicodeDecodeError, ValueError, TypeError, AttributeError):
            # Malformed datagram (not UTF-8, not JSON, not an object, or a
            # non-numeric emf). Skip it and keep reading.
            continue
    return latest

# --------------------------
# Multi-band filter
# --------------------------
def band_split(signal, bands=BANDS, fs=44100):
    """Split *signal* into *bands* log-spaced band-pass components.

    Band edges are log-spaced from 20 Hz to 20 kHz. The top edge is clamped
    just below Nyquist when *fs* is too low for 20 kHz. Each band is a
    2nd-order Butterworth band-pass applied with ``sosfilt``.

    NOTE(review): each call filters from a zero initial state. Low bands
    whose period is long compared with the block length are therefore
    dominated by the filter start-up transient.

    Args:
        signal: 1-D sample block.
        bands: number of bands.
        fs: sample rate in Hz.

    Returns:
        List of *bands* arrays, each the same length as *signal*, ordered
        from the lowest band to the highest.
    """
    return [sosfilt(sos, signal) for sos in _band_sos(bands, fs)]


# Filter designs keyed by (bands, fs). The design depends on nothing else, so
# there is no reason to rerun butter() on every animation frame.
_BAND_SOS_CACHE = {}


def _band_sos(bands, fs):
    """Return band_split's SOS filters, designing them only once per (bands, fs)."""
    key = (bands, fs)
    cached = _BAND_SOS_CACHE.get(key)
    if cached is None:
        nyq = 0.5 * fs
        # Keep the top edge strictly below Nyquist. Otherwise butter() raises
        # for fs < ~40 kHz. At 44.1 kHz this is exactly 20 kHz, as before.
        top = min(20000.0, 0.99 * nyq)
        freqs = np.logspace(np.log10(20), np.log10(top), bands + 1)
        cached = [
            butter(2, [freqs[i] / nyq, freqs[i + 1] / nyq], btype='band', output='sos')
            for i in range(bands)
        ]
        _BAND_SOS_CACHE[key] = cached
    return cached

# --------------------------
# Scatter plot setup
# --------------------------
# Colour values are tanh envelopes from get_lattice_multiband, so the colour
# range is pinned to [0, 1]. Left unset, the norm autoscales once from the
# first drawn frame and stays locked to that frame's range; a silent first
# frame would collapse every later colour to one value.
# NOTE(review): assumes EMF readings are >= 0; negative values would fall
# below vmin and clip to the lowest colour.
scat = ax.scatter([], [], [], c=[], cmap='viridis', s=size_base,
                  vmin=0.0, vmax=1.0)

# --------------------------
# Lattice mapping
# --------------------------
def get_lattice_multiband(t=0, audio_signal=None, emf_signal=0):
    """Compute scatter coordinates, colours and sizes for one frame.

    The N lattice points are divided into BANDS contiguous groups, one per
    band returned by band_split. Within a group, point j is driven by sample j
    of its band. The soft envelope of ``|sample| + emf_signal`` does three
    things:
      - pushes the point outward along its spiral radius (scaled by audio_scale);
      - sets its colour value;
      - sets its marker size (size_base + 50 * envelope).
    The Morph slider then blends every group toward a straight diagonal from
    (-1, -1, -1) to (1, 1, 1). The Amplitude slider scales z.

    Args:
        t: frame time in seconds. Currently unused.
        audio_signal: 1-D audio block. Defaults to 1024 zeros (silence).
        emf_signal: scalar added to every band magnitude before the envelope.

    Returns:
        Tuple ``(x, y, z, c, s)`` of five length-N arrays.
    """
    if audio_signal is None:
        audio_signal = np.zeros(1024)
    bands = band_split(audio_signal, bands=BANDS)
    x, y, z, c, s = (np.zeros(N) for _ in range(5))
    # Slider values are constant for the frame, so read them once.
    morph = morph_var.get()
    ampl = ampl_var.get()
    # array_split covers all N points even when N is not a multiple of BANDS.
    # A plain N // BANDS split would leave the last N % BANDS points unset.
    groups = np.array_split(np.arange(N), BANDS)
    for band, idx in zip(bands, groups):
        count = len(idx)
        magnitude = np.abs(band[:count])
        if magnitude.size < count:
            # Audio block shorter than the group: pad with silence so the
            # shapes still match. EMF still modulates the padded points.
            magnitude = np.pad(magnitude, (0, count - magnitude.size))
        mod = soft_envelope(magnitude + emf_signal, k=env_k)
        r = radii[idx] + mod * audio_scale
        ramp = np.linspace(-1, 1, count)
        x[idx] = r * np.cos(angles[idx]) * (1 - morph) + ramp * morph
        y[idx] = r * np.sin(angles[idx]) * (1 - morph) + ramp * morph
        z[idx] = zs[idx] * ampl * (1 - morph) + ramp * morph
        # Colour & size encoding
        c[idx] = mod
        s[idx] = size_base + mod * 50
    return x, y, z, c, s

# --------------------------
# Animation update
# --------------------------
def update(frame):
    """FuncAnimation callback: redraw the lattice for animation frame *frame*.

    Polls the UDP EMF sensor and combines it with the latest audio block to
    recompute the lattice. Pushes the positions, colours and sizes into the
    scatter artist, then re-applies the fixed axis limits.

    Returns:
        One-element tuple holding the updated scatter artist.
    """
    emf = get_wifi_emf()
    x, y, z, c, s = get_lattice_multiband(
        t=frame * dt,
        audio_signal=audio_buffer,
        emf_signal=emf,
    )
    scat._offsets3d = (x, y, z)  # 3D scatters keep their data here (private attr)
    scat.set_array(c)
    scat.set_sizes(s)
    for set_limits, lo, hi in ((ax.set_xlim, -5, 5),
                               (ax.set_ylim, -5, 5),
                               (ax.set_zlim, -2, 2)):
        set_limits(lo, hi)
    return (scat,)

# Keep a reference to the animation; without one it can be garbage-collected
# and stop. Frames are an open-ended counter, so caching frame data is
# pointless; cache_frame_data=False turns it off.
ani = FuncAnimation(fig, update, interval=dt*1000, blit=False,
                    cache_frame_data=False)
try:
    root.mainloop()
finally:
    # Release the audio device and the UDP port once the window is closed.
    stream.close()
    sock.close()
